import numpy as np
import random


def solve_tsp(dataset):
    """Solve every instance with Concorde and report each tour's length.

    Args:
        dataset: Iterable of instances; each is a 2-D array whose columns 0
            and 1 are passed to Concorde as the x and y city coordinates.

    Returns:
        List of ``(tour, length)`` pairs, one per instance. ``tour`` is the
        city order returned by Concorde (start city not repeated at the end)
        and ``length`` is the closed-tour length recomputed with
        ``calculate_distance``.

    Note:
        ``norm="EUC_2D"`` follows the TSPLIB convention of rounding each edge
        length to the nearest integer, so Concorde optimises the rounded
        distances; for small coordinate ranges the tour can be slightly
        suboptimal under exact Euclidean distances.
    """
    # Function-scope import: the rest of the module works without `concorde`.
    from concorde.tsp import TSPSolver

    results = []
    for instance in dataset:
        xs, ys = instance[:, 0], instance[:, 1]
        tour = TSPSolver.from_data(xs, ys, norm="EUC_2D").solve().tour.tolist()
        closed = tour + [tour[0]]  # close the loop back to the start city
        total = sum(
            calculate_distance(instance[a], instance[b])
            for a, b in zip(closed, closed[1:])
        )
        results.append((tour, total))
    return results


def nearest_neighbor_tsp(dataset):
    """Build one tour per instance with the nearest-neighbour heuristic.

    Starting from city 0, repeatedly move to the closest not-yet-visited city
    (ties go to the lowest city index), then close the loop back to city 0.

    Args:
        dataset: Iterable of instances; in each instance ``data[i]`` is the
            coordinate vector of city ``i``.

    Returns:
        List of ``(tour, tour_length)`` pairs, one per instance. ``tour`` lists
        the visited city indices as Python ints (start city not repeated at the
        end). ``tour_length`` includes the closing edge back to the start. An
        instance with no cities yields ``([], 0.0)``.
    """
    solutions = []

    for data in dataset:
        points = np.asarray(data, dtype=float)
        num_cities = len(points)
        if num_cities == 0:
            # Nothing to visit; the original crashed on visited[0] here.
            solutions.append(([], 0.0))
            continue
        # Treat 1-D input as one coordinate per city so axis=1 norms work.
        points = points.reshape(num_cities, -1)

        visited = np.zeros(num_cities, dtype=bool)
        current_city = 0
        visited[current_city] = True
        tour = [current_city]
        tour_length = 0.0

        for _ in range(1, num_cities):
            # Distances from the current city to every city, with visited ones masked out.
            dists = np.linalg.norm(points - points[current_city], axis=1)
            dists[visited] = np.inf
            # argmin returns the first minimum, so ties pick the lowest index.
            nearest_city = int(np.argmin(dists))
            tour_length += dists[nearest_city]
            tour.append(nearest_city)
            visited[nearest_city] = True
            current_city = nearest_city

        # Closing edge back to the start city.
        tour_length += np.linalg.norm(points[current_city] - points[tour[0]])

        solutions.append((tour, tour_length))

    return solutions


def calculate_distance(point1, point2):
    """Return the Euclidean (L2) distance between two coordinate arrays."""
    offset = point1 - point2
    return np.linalg.norm(offset)


def find_fastest_insertion_path(data):
    """Build a closed tour with the (cheapest) insertion heuristic.

    Starts from the degenerate loop ``[0, 0]``. At every step it considers
    every unvisited city ``c`` and every edge ``(path[i], path[i + 1])`` of the
    current loop. It picks the pair that minimises the insertion cost
    ``d(path[i], c) + d(c, path[i + 1]) - d(path[i], path[i + 1])`` and inserts
    ``c`` between the two endpoints. Ties go to the lowest city index, then to
    the earliest edge. Despite the "fastest" in the name, this is the classic
    cheapest-insertion rule.

    Args:
        data: City coordinates; ``data[i]`` is the position of city ``i``.

    Returns:
        List of city indices (Python ints) forming a closed loop. It starts
        and ends with city 0 and contains every other city exactly once.
        Returns ``[]`` when ``data`` has no cities.
    """
    points = np.asarray(data, dtype=float)
    num_cities = len(points)
    if num_cities == 0:
        # The original crashed on visited[0] for an empty instance.
        return []
    # Treat 1-D input as one coordinate per city.
    points = points.reshape(num_cities, -1)

    # Pairwise Euclidean distances, computed once instead of re-evaluating
    # norms inside the search loop.
    dist = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)

    visited = np.zeros(num_cities, dtype=bool)
    visited[0] = True
    path = [0, 0]  # Start with a loop from the first city to itself

    for _ in range(1, num_cities):
        candidates = np.flatnonzero(~visited)
        starts = np.array(path[:-1])
        ends = np.array(path[1:])
        # cost[k, i]: extra length from inserting candidates[k] between path[i] and path[i + 1].
        cost = (
            dist[np.ix_(candidates, starts)]
            + dist[np.ix_(candidates, ends)]
            - dist[starts, ends]
        )
        # argmin on the row-major flattened matrix returns the first minimum:
        # lowest city index first, then earliest edge (the scan order of a
        # city-outer / position-inner loop).
        city_idx, position = np.unravel_index(np.argmin(cost), cost.shape)
        best_city = int(candidates[city_idx])
        path.insert(int(position) + 1, best_city)
        visited[best_city] = True

    return path


def fastest_insertion_tsp(dataset):
    """Run the insertion heuristic on every instance and measure each tour.

    Args:
        dataset: Iterable of instances accepted by
            ``find_fastest_insertion_path``.

    Returns:
        List of ``(tour, length)`` pairs, one per instance. ``tour`` is the
        closed path with its trailing return-to-start entry dropped. ``length``
        sums ``calculate_distance`` over every consecutive pair of the closed
        path, so it includes the edge back to the start.
    """
    results = []
    for instance in dataset:
        closed_path = find_fastest_insertion_path(instance)
        # Walk consecutive pairs of the closed path (last pair returns to start).
        total = sum(
            calculate_distance(instance[a], instance[b])
            for a, b in zip(closed_path, closed_path[1:])
        )
        results.append((closed_path[:-1], total))
    return results


def random_sample_tsp(dataset):
    """Score a uniformly random city ordering per instance as a baseline.

    The permutation comes from ``random.sample`` on the module-level RNG, so
    ``random.seed`` controls reproducibility.

    Args:
        dataset: Iterable of instances; ``len(instance)`` is the city count and
            ``instance[i]`` the coordinates of city ``i``.

    Returns:
        List of ``(tour, tour_length)`` pairs. ``tour`` is the permutation
        (start not repeated) and ``tour_length`` includes the closing edge.
    """
    results = []
    for instance in dataset:
        n = len(instance)
        order = random.sample(range(n), n)
        # Successor of each city in the cycle: shift left by one and wrap around.
        successors = order[1:] + order[:1]
        total = sum(
            calculate_distance(instance[a], instance[b])
            for a, b in zip(order, successors)
        )
        results.append((order, total))
    return results


def _gap_percent(mean_length, opt_mean_length):
    """Return how much longer ``mean_length`` is than ``opt_mean_length``, in percent."""
    return (mean_length - opt_mean_length) / opt_mean_length * 100


def main(dataset_path='../dataset/tsp/tsp_dataset_n=100_c=10.npy'):
    """Run all TSP baselines on a saved dataset and report tours and mean gaps.

    Loads the instances with ``np.load`` and solves each one with:
    - Concorde (used as the optimum),
    - nearest neighbour,
    - the insertion heuristic,
    - a random permutation.

    It prints every instance's coordinates and tours. It then prints each
    method's mean tour length and its percentage gap to the Concorde mean,
    including several hard-coded reference results.

    Args:
        dataset_path: Path to a ``.npy`` file. It is indexed as
            ``dataset[i, j, 0]`` / ``dataset[i, j, 1]``, i.e. instances x
            cities x coordinates. The default keeps the original hard-coded
            path.
    """
    dataset = np.load(dataset_path)
    solutions = solve_tsp(dataset)
    solutions_nn = nearest_neighbor_tsp(dataset)
    solutions_fi = fastest_insertion_tsp(dataset)
    solutions_rs = random_sample_tsp(dataset)

    if not solutions:
        # Avoid dividing by zero when computing the means below.
        print('No instances found in dataset.')
        return

    for i, (solution, length) in enumerate(solutions):
        print('-' * 20)
        for j in range(len(dataset[i])):
            print(f'({j}):({dataset[i, j, 0]},{dataset[i, j, 1]})')
        print(f"Opt solution {i}: {solution}", end='\t')
        print(f"Path length: {length}")
        print(f"NN solution {i}: {solutions_nn[i][0]}", end='\t')
        print(f"Path length: {solutions_nn[i][1]}")
        print(f"FI solution {i}: {solutions_fi[i][0]}", end='\t')
        print(f"Path length: {solutions_fi[i][1]}")
        # Report the random-sample result (previously echoed the optimal one).
        print(f"RS solution {i}: {solutions_rs[i][0]}", end='\t')
        print(f"Path length: {solutions_rs[i][1]}")

    num_instances = len(solutions)
    opt_mean_length = sum(length for _, length in solutions) / num_instances
    rs_mean_length = sum(length for _, length in solutions_rs[:num_instances]) / num_instances
    nn_mean_length = sum(length for _, length in solutions_nn[:num_instances]) / num_instances
    fi_mean_length = sum(length for _, length in solutions_fi[:num_instances]) / num_instances

    # Mean tour lengths of the LLM-based methods. They are hard-coded, not
    # computed here; presumably copied from separate experiment runs on this
    # dataset (TODO confirm provenance).
    reference_mean_lengths = [
        ('LLM IO', 327.4896011307043),
        ('CoT', 327.2126708654643),
        ('LLM AR', 325.27237881534),
        ('LLM AR SC', 317.848934334),
        ('VAP', 312.432911333),
    ]
    # Alternate reference values kept from an earlier revision. Which dataset
    # they belong to is not recorded here, so confirm before use:
    # LLM IO 544.4520162694643, CoT 547.1768162654643, LLM AR 541.87217881534,
    # LLM AR SC 518.534934334, VAP 497.332211333

    print(f'Opt mean length: {opt_mean_length:.2f}, gap: 0.00%')
    rows = [
        ('RS', rs_mean_length),
        ('NN', nn_mean_length),
        ('FI', fi_mean_length),
        *reference_mean_lengths,
    ]
    for label, mean_length in rows:
        gap = _gap_percent(mean_length, opt_mean_length)
        print(f'{label} mean length: {mean_length:.2f}, gap: {gap:.2f}%')


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
